TBATS forecaster (multiple seasonality)
TBATS extends BATS with Trigonometric seasonality (Fourier-style terms), which is especially useful when one of the seasonal periods is large (e.g., yearly seasonality with daily data).
This notebook implements a practical TBATS-style forecaster with an interface like:
tbats = TBATS(use_box_cox=..., box_cox_bounds=..., seasonal_periods=..., seasonal_harmonics=..., use_arma_errors=...)
model = tbats.fit(y)
forecast = model.forecast(steps)
Implementation note: the original TBATS model is state-space / exponential smoothing based. Here we implement a TBATS-style forecaster using:
explicit trend + Fourier seasonal design matrices, and
ARMA errors estimated via statsmodels (SARIMAX with d=0).
Model sketch (math)
As with BATS, optionally apply Box–Cox and model \(x_t=g_\lambda(y_t)\).
TBATS uses a Fourier (trigonometric) seasonal representation for each seasonal period \(m\), with \(K\) harmonics: $$S_t^{(m)} = \sum_{k=1}^{K} \left(a_k\cos\left(\frac{2\pi k t}{m}\right) + b_k\sin\left(\frac{2\pi k t}{m}\right)\right).$$
This uses \(2K\) parameters per seasonality instead of \(m-1\) seasonal dummies (BATS), which is much smaller when \(m\) is large.
import warnings
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os
import plotly.io as pio
from scipy import stats
import statsmodels.api as sm
# --- Notebook-wide configuration ---------------------------------------------
# Hide UserWarning messages for the rest of the notebook.
warnings.filterwarnings("ignore", category=UserWarning)

# Plotly renderer comes from the PLOTLY_RENDERER environment variable when it
# is set, otherwise the "notebook" renderer is used.
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
pio.templates.default = "plotly_white"

# Seeded generator so the simulated data below is reproducible.
rng = np.random.default_rng(7)

# --- Report the library versions this notebook ran with -----------------------
import numpy
import pandas
import scipy
import statsmodels
import plotly

for _pkg in (numpy, pandas, scipy, statsmodels, plotly):
    print(f"{_pkg.__name__}:", _pkg.__version__)
numpy: 1.26.2
pandas: 2.1.3
scipy: 1.15.0
statsmodels: 0.14.4
plotly: 6.5.2
class BoxCoxTransformer:
    """Optional Box-Cox transform with an automatic positivity shift.

    When ``use_box_cox`` is False the transformer is a pass-through that
    returns copies of its input. Otherwise ``fit`` chooses a shift that makes
    the data strictly positive and estimates the Box-Cox ``lambda`` by maximum
    likelihood, restricted to ``box_cox_bounds``.

    Attributes set by ``fit``:
        shift_: constant added to the data before transforming (0.0 when the
            data is already strictly positive, else ``1 - min(y)``).
        lambda_: fitted Box-Cox parameter, or None when Box-Cox is disabled.
    """

    def __init__(self, use_box_cox: bool, box_cox_bounds: tuple[float, float] = (0.0, 1.0)):
        self.use_box_cox = bool(use_box_cox)
        self.box_cox_bounds = tuple(float(v) for v in box_cox_bounds)
        self.shift_: float = 0.0
        self.lambda_: float | None = None

    def fit(self, y: np.ndarray) -> "BoxCoxTransformer":
        """Estimate the shift and lambda from ``y`` and return ``self``.

        Raises:
            ValueError: if Box-Cox is enabled and ``y`` contains non-finite
                values, or is not strictly positive after the shift.
        """
        y = np.asarray(y, dtype=float)
        if not self.use_box_cox:
            self.shift_ = 0.0
            self.lambda_ = None
            return self

        # NaN/inf would otherwise propagate silently into shift_ and lambda_.
        if not np.all(np.isfinite(y)):
            raise ValueError("Box-Cox requires finite data.")

        min_y = float(np.min(y))
        self.shift_ = 0.0 if min_y > 0.0 else (1.0 - min_y)
        y_pos = y + self.shift_
        if np.any(y_pos <= 0.0):
            raise ValueError("Box-Cox requires strictly positive data (even after shift).")

        lo, hi = sorted(self.box_cox_bounds)
        if lo == hi:
            # Degenerate interval: nothing to optimise.
            self.lambda_ = float(lo)
            return self

        # `brack` only seeds Brent's search and does not constrain the result,
        # so use a bounded scalar minimiser to actually honour the bounds.
        from scipy import optimize

        def _bounded(neg_llf):
            return optimize.minimize_scalar(neg_llf, bounds=(lo, hi), method="bounded")

        lmbda = stats.boxcox_normmax(y_pos, method="mle", optimizer=_bounded)
        # Final clip guards against any post-processing moving lambda outside.
        self.lambda_ = float(np.clip(lmbda, lo, hi))
        return self

    def transform(self, y: np.ndarray) -> np.ndarray:
        """Map ``y`` to the transformed scale (a copy of ``y`` if disabled).

        Raises:
            RuntimeError: if Box-Cox is enabled and ``fit`` was not called.
            ValueError: if any shifted value is not strictly positive.
        """
        y = np.asarray(y, dtype=float)
        if not self.use_box_cox:
            return y.copy()
        if self.lambda_ is None:
            raise RuntimeError("Call fit() before transform().")
        y_pos = y + self.shift_
        if np.any(y_pos <= 0.0):
            raise ValueError("Box-Cox requires strictly positive data (even after shift).")
        lmbda = float(self.lambda_)
        if abs(lmbda) < 1e-10:
            # lambda ~ 0 is the log-transform limit of Box-Cox.
            return np.log(y_pos)
        return (np.power(y_pos, lmbda) - 1.0) / lmbda

    def inverse_transform(self, x: np.ndarray) -> np.ndarray:
        """Map transformed values ``x`` back to the original scale.

        Values outside the range of the Box-Cox transform (where
        ``lambda * x + 1`` is negative, e.g. a wide lower interval bound) are
        clamped to the boundary of that range instead of producing NaN.

        Raises:
            RuntimeError: if Box-Cox is enabled and ``fit`` was not called.
        """
        x = np.asarray(x, dtype=float)
        if not self.use_box_cox:
            return x.copy()
        if self.lambda_ is None:
            raise RuntimeError("Call fit() before inverse_transform().")
        lmbda = float(self.lambda_)
        if abs(lmbda) < 1e-10:
            y_pos = np.exp(x)
        else:
            # A negative base raised to a fractional power is NaN; clamp at 0.
            base = np.maximum(lmbda * x + 1.0, 0.0)
            y_pos = np.power(base, 1.0 / lmbda)
        return y_pos - self.shift_
def _acf(x: np.ndarray, max_lag: int) -> tuple[np.ndarray, np.ndarray]:
x = np.asarray(x, dtype=float)
x = x - x.mean()
denom = float(np.dot(x, x))
lags = np.arange(max_lag + 1)
values = np.zeros(max_lag + 1)
values[0] = 1.0
if denom == 0.0:
return lags, values
for k in range(1, max_lag + 1):
values[k] = float(np.dot(x[k:], x[:-k]) / denom)
return lags, values
def trend_feature(t: np.ndarray, *, use_damped: bool, damped_phi: float) -> np.ndarray:
    """Build the trend regressor for time index ``t``.

    Without damping the regressor is ``t`` itself (as floats). With damping it
    is ``(1 - phi**t) / (1 - phi)``.

    Raises:
        ValueError: if damping is requested and ``damped_phi`` is not in (0, 1).
    """
    steps = np.asarray(t, dtype=float)
    if use_damped:
        phi = float(damped_phi)
        if not (0.0 < phi < 1.0):
            raise ValueError("damped_phi must be in (0, 1)")
        return (1.0 - np.power(phi, steps)) / (1.0 - phi)
    # phi is deliberately not validated when damping is off.
    return steps
def fourier_terms(t: np.ndarray, period: int, K: int) -> np.ndarray:
    """Return Fourier (cos/sin) regressors for one seasonal period.

    Columns are ordered ``cos(2*pi*1*t/period), sin(2*pi*1*t/period),
    cos(2*pi*2*t/period), ...`` for harmonics ``1..K``.

    Args:
        t: time index; the regressors are evaluated at these values.
        period: seasonal period. ``period <= 1`` yields no columns.
        K: requested number of harmonics, capped at ``period // 2``.
            ``K <= 0`` yields no columns.

    Returns:
        Float array with one row per element of ``t``. It has ``2*K`` columns,
        except that when the last harmonic is ``k = period / 2`` (even period)
        its sine column is omitted. On an integer time index that column is
        ``sin(pi*t) = 0``, and an all-zero regressor would make the design
        matrix rank-deficient.
    """
    t = np.asarray(t, dtype=float)
    period = int(period)
    K = int(K)
    if period <= 1 or K <= 0:
        return np.zeros((t.size, 0), dtype=float)

    K = min(K, period // 2)
    cols: list[np.ndarray] = []
    for k in range(1, K + 1):
        ang = 2.0 * np.pi * k * t / period
        cols.append(np.cos(ang))
        # Skip the degenerate sine at the period/2 harmonic (see docstring).
        if 2 * k != period:
            cols.append(np.sin(ang))
    return np.column_stack(cols).astype(float)
def tbats_design_matrix(
t: np.ndarray,
*,
use_trend: bool,
use_damped_trend: bool,
damped_trend_phi: float,
seasonal_periods: list[int] | None,
seasonal_harmonics: list[int] | None,
) -> np.ndarray:
t = np.asarray(t, dtype=int)
cols = [np.ones((t.size, 1), dtype=float)]
if use_trend:
cols.append(trend_feature(t.astype(float), use_damped=use_damped_trend, damped_phi=damped_trend_phi).reshape(-1, 1))
if seasonal_periods:
periods = [int(m) for m in seasonal_periods]
if seasonal_harmonics is None:
harmonics = [min(10, m // 2) for m in periods]
else:
harmonics = [int(k) for k in seasonal_harmonics]
if len(harmonics) != len(periods):
raise ValueError("seasonal_harmonics must match seasonal_periods length")
for m, K in zip(periods, harmonics):
cols.append(fourier_terms(t.astype(float), period=m, K=K))
return np.concatenate(cols, axis=1)
class TBATSModel:
    """A fitted TBATS-style model: regression results plus the settings
    needed to rebuild the design matrix for future time steps.

    ``results`` is the fitted results object returned by the SARIMAX fit in
    ``TBATS.fit``; ``transformer`` is the fitted ``BoxCoxTransformer`` used on
    the training data. ``y_index`` is stored as given and not used by any
    method here.
    """

    def __init__(
        self,
        *,
        results,
        transformer: BoxCoxTransformer,
        use_trend: bool,
        use_damped_trend: bool,
        damped_trend_phi: float,
        seasonal_periods: list[int] | None,
        seasonal_harmonics: list[int] | None,
        y_index,
    ):
        # Fitted pieces.
        self.results = results
        self.transformer = transformer
        # Design-matrix settings, replayed when forecasting.
        self.use_trend = use_trend
        self.use_damped_trend = use_damped_trend
        self.damped_trend_phi = float(damped_trend_phi)
        self.seasonal_periods = seasonal_periods
        self.seasonal_harmonics = seasonal_harmonics
        # Index of the training series (None if a plain array was fitted).
        self.y_index = y_index

    @property
    def n_obs(self) -> int:
        """Number of observations the model was fitted on."""
        return int(self.results.nobs)

    def fitted_values(self) -> np.ndarray:
        """In-sample fitted values, mapped back to the original scale."""
        in_sample = np.asarray(self.results.fittedvalues, dtype=float)
        return self.transformer.inverse_transform(in_sample)

    def residuals(self) -> np.ndarray:
        """Model residuals on the transformed scale (no inverse transform)."""
        return np.asarray(self.results.resid, dtype=float)

    def forecast(self, steps: int, *, alpha: float = 0.05) -> dict[str, np.ndarray]:
        """Forecast ``steps`` periods past the end of the training data.

        Returns a dict with keys ``"mean"``, ``"lower"`` and ``"upper"``: the
        point forecast and the ``1 - alpha`` interval bounds, each passed
        through the inverse transform.
        """
        horizon = int(steps)
        first = self.n_obs
        # Future rows continue the time index where training stopped.
        exog_future = tbats_design_matrix(
            np.arange(first, first + horizon),
            use_trend=self.use_trend,
            use_damped_trend=self.use_damped_trend,
            damped_trend_phi=self.damped_trend_phi,
            seasonal_periods=self.seasonal_periods,
            seasonal_harmonics=self.seasonal_harmonics,
        )
        prediction = self.results.get_forecast(steps=horizon, exog=exog_future)
        center = np.asarray(prediction.predicted_mean, dtype=float)
        bounds = np.asarray(prediction.conf_int(alpha=alpha))

        on_model_scale = {"mean": center, "lower": bounds[:, 0], "upper": bounds[:, 1]}
        to_original = self.transformer.inverse_transform
        return {name: to_original(values) for name, values in on_model_scale.items()}
class TBATS:
    """TBATS-style forecaster: regression on trend + Fourier terms with ARMA errors.

    The constructor only stores (type-coerced) settings; ``fit`` does the work
    and returns a ``TBATSModel``.

    Args:
        use_box_cox: apply a Box-Cox transform before fitting.
        box_cox_bounds: bounds passed to ``BoxCoxTransformer``.
        use_trend: include a trend regressor.
        use_damped_trend: use the damped form of the trend regressor.
        damped_trend_phi: damping parameter for the damped trend.
        seasonal_periods: seasonal periods, or None for no seasonality.
        seasonal_harmonics: harmonics per period, or None for the defaults
            chosen by ``tbats_design_matrix``.
        use_arma_errors: if False, the error order is fixed at (0, 0).
        arma_order: fixed ``(p, q)``; pass None to select the order by AIC.
        max_arma_order: largest ``p`` and ``q`` tried during AIC selection.
        show_warnings: if True, ``fit`` prints the chosen order and its AIC.
    """

    def __init__(
        self,
        *,
        use_box_cox: bool = False,
        box_cox_bounds: tuple[float, float] = (0.0, 1.0),
        use_trend: bool = True,
        use_damped_trend: bool = False,
        damped_trend_phi: float = 0.98,
        seasonal_periods: list[int] | None = None,
        seasonal_harmonics: list[int] | None = None,
        use_arma_errors: bool = True,
        arma_order: tuple[int, int] | None = (1, 1),
        max_arma_order: int = 1,
        show_warnings: bool = True,
    ):
        self.use_box_cox = bool(use_box_cox)
        self.box_cox_bounds = tuple(float(v) for v in box_cox_bounds)
        self.use_trend = bool(use_trend)
        self.use_damped_trend = bool(use_damped_trend)
        self.damped_trend_phi = float(damped_trend_phi)
        self.seasonal_periods = None if seasonal_periods is None else [int(m) for m in seasonal_periods]
        self.seasonal_harmonics = None if seasonal_harmonics is None else [int(k) for k in seasonal_harmonics]
        self.use_arma_errors = bool(use_arma_errors)
        self.arma_order = None if arma_order is None else (int(arma_order[0]), int(arma_order[1]))
        self.max_arma_order = int(max_arma_order)
        self.show_warnings = bool(show_warnings)

    def _fit_sarimax(self, y_x: np.ndarray, X: np.ndarray, order: tuple[int, int]) -> tuple[object, float]:
        """Fit SARIMAX with ARMA(p, q) errors on ``y_x`` with regressors ``X``.

        Returns ``(results, aic)``. ``trend="n"`` because the intercept is
        already a column of the design matrix.
        """
        p, q = order
        model = sm.tsa.SARIMAX(
            y_x,
            exog=X,
            order=(p, 0, q),
            trend="n",
            enforce_stationarity=True,
            enforce_invertibility=True,
        )
        res = model.fit(disp=False, method="lbfgs", maxiter=300)
        return res, float(res.aic)

    def _select_arma_order(self, y_x: np.ndarray, X: np.ndarray) -> tuple[int, int]:
        """Return the ``(p, q)`` in ``0..max_arma_order`` with the lowest AIC.

        Candidates whose fit raises are skipped (best effort).

        Raises:
            RuntimeError: if no candidate produced a finite-or-better AIC; the
                last fitting error, if any, is attached as the cause.
        """
        limit = self.max_arma_order + 1
        candidates = [(p, q) for p in range(limit) for q in range(limit)]

        best_order = (0, 0)
        best_aic = np.inf
        last_error: Exception | None = None
        for order in candidates:
            try:
                _, aic = self._fit_sarimax(y_x, X, order)
            except Exception as exc:
                # Keep going, but remember why, so a total failure is diagnosable.
                last_error = exc
                continue
            if aic < best_aic:
                best_aic = aic
                best_order = order

        if best_aic == np.inf:
            raise RuntimeError("Failed to fit any ARMA(p,q) candidate.") from last_error
        return best_order

    def fit(self, y) -> TBATSModel:
        """Fit the model to ``y`` and return a ``TBATSModel``.

        ``y`` may be a pandas Series (its index is kept on the returned model)
        or anything convertible to a float array.

        Raises:
            ValueError: if ``y`` is empty.
        """
        if isinstance(y, pd.Series):
            y_index = y.index
            y_np = y.to_numpy(dtype=float)
        else:
            y_index = None
            y_np = np.asarray(y, dtype=float)

        if y_np.size == 0:
            raise ValueError("y must contain at least one observation.")

        # Time index 0..n-1; forecasting continues from n.
        t = np.arange(y_np.size)
        X = tbats_design_matrix(
            t,
            use_trend=self.use_trend,
            use_damped_trend=self.use_damped_trend,
            damped_trend_phi=self.damped_trend_phi,
            seasonal_periods=self.seasonal_periods,
            seasonal_harmonics=self.seasonal_harmonics,
        )

        transformer = BoxCoxTransformer(self.use_box_cox, box_cox_bounds=self.box_cox_bounds).fit(y_np)
        y_x = transformer.transform(y_np)

        if not self.use_arma_errors:
            chosen_order = (0, 0)
        elif self.arma_order is not None:
            chosen_order = self.arma_order
        else:
            chosen_order = self._select_arma_order(y_x, X)

        res, aic = self._fit_sarimax(y_x, X, chosen_order)

        if self.show_warnings:
            print(f"Chosen ARMA(p,q) = {chosen_order}, AIC = {aic:.2f}")

        return TBATSModel(
            results=res,
            transformer=transformer,
            use_trend=self.use_trend,
            use_damped_trend=self.use_damped_trend,
            damped_trend_phi=self.damped_trend_phi,
            seasonal_periods=self.seasonal_periods,
            seasonal_harmonics=self.seasonal_harmonics,
            y_index=y_index,
        )
Demo: long seasonality (weekly + yearly)
We’ll simulate daily data with two seasonalities:
weekly (\(m_1=7\))
yearly (\(m_2=365\))
A BATS dummy-season model would need roughly \((7-1)+(365-1)=370\) seasonal parameters (plus trend). TBATS can represent the same seasonal patterns with a small number of harmonics.
def simulate_arma11(n: int, *, phi: float, theta: float, sigma: float, rng: np.random.Generator) -> np.ndarray:
    """Simulate ``n`` steps of an ARMA(1,1) process with Gaussian shocks.

    Recursion: ``u[t] = phi * u[t-1] + eps[t] + theta * eps[t-1]`` with
    ``eps ~ Normal(0, sigma)`` drawn from ``rng``. The first value is just
    ``eps[0]`` because there is no earlier state or shock.
    """
    shocks = rng.normal(0.0, sigma, size=n)
    series = np.zeros(n)
    if n > 0:
        series[0] = shocks[0]
    for i in range(1, n):
        series[i] = phi * series[i - 1] + shocks[i] + theta * shocks[i - 1]
    return series
# Four years of daily observations.
n = 4 * 365
idx = pd.date_range("2018-01-01", periods=n, freq="D")
t = np.arange(n)

# Deterministic components: one harmonic each for the weekly and yearly cycle,
# plus a slow linear trend.
weekly_angle = 2 * np.pi * t / 7
yearly_angle = 2 * np.pi * t / 365
weekly = 1.0 * np.sin(weekly_angle) + 0.3 * np.cos(weekly_angle)
yearly = 3.5 * np.sin(yearly_angle) + 1.7 * np.cos(yearly_angle)
trend = 0.003 * t

# Autocorrelated ARMA(1,1) noise drawn from the notebook's seeded generator.
noise = simulate_arma11(n, phi=0.5, theta=0.3, sigma=0.9, rng=rng)

y = pd.Series(80.0 + trend + weekly + yearly + noise, index=idx, name="y")

fig = go.Figure(data=[go.Scatter(x=y.index, y=y, name="y", line=dict(color="black"))])
fig.update_layout(title="Synthetic long-seasonality series", xaxis_title="date", yaxis_title="value")
fig.show()
# Compare parameter counts: BATS dummies vs TBATS Fourier
m1, m2 = 7, 365
K_weekly, K_yearly = 3, 10

# BATS: m - 1 dummy parameters for each seasonal period m.
bats_seasonal_params = sum(m - 1 for m in (m1, m2))
# TBATS: one cosine and one sine coefficient per harmonic.
tbats_seasonal_params = sum(2 * n_harm for n_harm in (K_weekly, K_yearly))

print("BATS seasonal parameters (dummies):", bats_seasonal_params)
print("TBATS seasonal parameters (Fourier):", tbats_seasonal_params)
BATS seasonal parameters (dummies): 370
TBATS seasonal parameters (Fourier): 26
# Train/test split + fit TBATS
h = 90  # hold out the last 90 days as the test window
y_train, y_test = y.iloc[:-h], y.iloc[-h:]

tbats = TBATS(
    use_box_cox=False,
    box_cox_bounds=(0.0, 1.0),
    use_trend=True,
    use_damped_trend=False,
    seasonal_periods=[7, 365],
    seasonal_harmonics=[K_weekly, K_yearly],
    use_arma_errors=True,
    arma_order=(1, 1),
    show_warnings=True,
)
model = tbats.fit(y_train)
fcst = model.forecast(h)

# Wrap the raw arrays in Series aligned to the matching date index.
fitted = pd.Series(model.fitted_values(), index=y_train.index)
pred_mean, pred_lower, pred_upper = (
    pd.Series(fcst[key], index=y_test.index) for key in ("mean", "lower", "upper")
)

# Trace order matters: the lower bound fills up to the trace just before it
# (the invisible upper bound), which draws the shaded interval band.
fig = go.Figure(
    data=[
        go.Scatter(x=y_train.index, y=y_train, name="train", line=dict(color="rgba(0,0,0,0.35)")),
        go.Scatter(x=y_train.index, y=fitted, name="fitted", line=dict(color="#59A14F")),
        go.Scatter(x=y_test.index, y=y_test, name="test", line=dict(color="black")),
        go.Scatter(x=y_test.index, y=pred_upper, line=dict(width=0), showlegend=False),
        go.Scatter(
            x=y_test.index,
            y=pred_lower,
            fill="tonexty",
            fillcolor="rgba(89,161,79,0.18)",
            line=dict(width=0),
            name="95% interval (approx)",
        ),
        go.Scatter(x=y_test.index, y=pred_mean, name="forecast mean", line=dict(color="#E15759")),
    ]
)
fig.update_layout(title="TBATS forecast on long-seasonality series", xaxis_title="date", yaxis_title="value")
fig.show()
Chosen ARMA(p,q) = (1, 1), AIC = 3512.80
# Residual diagnostics (in transformed space)
resid = model.residuals()
warmup = 10  # drop the first few residuals before computing diagnostics
resid_use = resid[warmup:]

resid_mean = resid_use.mean()
resid_std = resid_use.std(ddof=1)
print("residual mean:", float(resid_mean))
print("residual std:", float(resid_std))
print("Jarque-Bera:", stats.jarque_bera(resid_use))

lags, acf_vals = _acf(resid_use, max_lag=30)
# Approximate 95% band for the autocorrelation of white noise.
bound = 1.96 / np.sqrt(resid_use.size)

# QQ data: standard-normal quantiles at mid-point plotting positions versus
# the sorted standardised residuals.
nq = resid_use.size
p = (np.arange(1, nq + 1) - 0.5) / nq
theoretical = stats.norm.ppf(p)
sample_q = np.sort((resid_use - resid_mean) / resid_std)

green = "#59A14F"
dashed_gray = dict(color="gray", dash="dash")
dashed_black = dict(color="black", dash="dash")

fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=("Residuals over time", "Residual histogram", "Residual ACF", "QQ plot (std residuals)"),
)

# Top-left: residuals over time with a zero reference line.
fig.add_trace(
    go.Scatter(x=y_train.index[warmup:], y=resid_use, name="residuals", line=dict(color=green)),
    row=1,
    col=1,
)
fig.add_hline(y=0, line=dashed_black, row=1, col=1)

# Top-right: histogram.
fig.add_trace(go.Histogram(x=resid_use, nbinsx=30, name="hist", marker_color=green), row=1, col=2)

# Bottom-left: ACF bars with the +/- band drawn as two horizontal lines.
fig.add_trace(go.Bar(x=lags, y=acf_vals, name="ACF(resid)", marker_color=green), row=2, col=1)
for level in (bound, -bound):
    fig.add_trace(
        go.Scatter(x=[0, lags.max()], y=[level, level], mode="lines", line=dashed_gray, showlegend=False),
        row=2,
        col=1,
    )

# Bottom-right: QQ scatter with the y = x reference line.
fig.add_trace(
    go.Scatter(x=theoretical, y=sample_q, mode="markers", name="QQ", marker=dict(color=green)),
    row=2,
    col=2,
)
q_range = [theoretical.min(), theoretical.max()]
fig.add_trace(
    go.Scatter(x=q_range, y=q_range, mode="lines", line=dashed_black, showlegend=False),
    row=2,
    col=2,
)

fig.update_layout(height=750, title="TBATS residual diagnostics")
fig.show()
residual mean: 5.158801683041921e-05
residual std: 0.8543250028588507
Jarque-Bera: SignificanceResult(statistic=0.5576490029991557, pvalue=0.7566726864848627)